package login

import (
	"gosh/config"
	"gosh/pkg/model"
	"io/ioutil"
	"os"
	"path"
	"regexp"
	"strings"
	"time"
)

// LoginConfig is the package-level Config instance.
var LoginConfig = &Config{}

// Config holds the SSH connection settings for a target host and, optionally,
// a bastion (jump) host. Parse fills in the connection fields that were left
// unset, using the host database when it is enabled and built-in defaults
// otherwise.
type Config struct {
	// Target host connection settings.
	Host     string
	Port     int
	User     string
	Password string
	// Key is the private key file path; KeyContent holds that file's contents
	// once Parse has read it.
	Key, KeyContent, KeyPassphrase string

	// Bastion (jump) host connection settings. After Parse, an empty
	// BastionHost means the target is reached directly.
	BastionHost     string
	BastionPort     int
	BastionUser     string
	BastionPassword string
	// BastionKey is the bastion's private key file path; BastionKeyContent
	// holds that file's contents once Parse has read it.
	BastionKey, BastionKeyContent, BastionKeyPassphrase string

	ConnectTimeout time.Duration

	// Labels are "key=value" selectors that Parse merges into Condition.
	Labels []string
	// Condition is the filter used to look up the host record in the database.
	Condition map[string]string
}

// labelPattern matches a "key=value" label. The key is everything before the
// first '=' and must be non-empty. The value is everything after it (it may
// itself contain '=') and must be non-empty. It is compiled once at package
// init rather than on every label.
var labelPattern = regexp.MustCompile(`(?s)^([^=]+)=(.+)$`)

// labelToCondition merges c.Labels into c.Condition as lookup conditions.
//
// Each label must look like "key=value". The split happens at the first '=',
// so "url=a=b" yields url -> "a=b". Labels that do not fit this shape (no '=',
// empty key or empty value) are ignored. A later label overrides an earlier
// one with the same key. A label keyed "connect_ip" overrides the value Parse
// seeds from Host. A nil Condition map is allocated rather than panicking.
func (c *Config) labelToCondition() {
	if c.Condition == nil {
		c.Condition = make(map[string]string, len(c.Labels))
	}
	for _, label := range c.Labels {
		m := labelPattern.FindStringSubmatch(label)
		if m == nil {
			continue
		}
		c.Condition[m[1]] = m[2]
	}
}

// Parse resolves the final connection settings for the target host and, when
// one is configured, for the bastion (jump) host.
//
// c.Condition is rebuilt from c.Host (as "connect_ip") and c.Labels, then used
// to look up database records via getHostFromDB. Each setting is resolved in
// this order:
//   - a value already set on c wins;
//   - otherwise the database record's value is used when present (a non-empty
//     string, or a positive port);
//   - otherwise a fallback applies: $USER for users, 22 for ports, and
//     $HOME/.ssh/id_rsa for key files. Passwords and passphrases stay empty.
//
// The private key file(s) are then read into KeyContent / BastionKeyContent.
// If neither c nor the database names a bastion, Parse returns once the target
// host is resolved. Errors from the database lookup or from reading a key file
// are returned unchanged.
func (c *Config) Parse() error {
	c.Condition = map[string]string{"connect_ip": c.Host}
	c.labelToCondition()

	host, bastion, err := c.getHostFromDB()
	if err != nil {
		return err
	}

	// Database values for the target host; all stay zero when there is no record.
	var hostUser, hostPassword, hostKey, hostPassphrase string
	var hostPort int
	if host != nil {
		hostUser, hostPort = host.SSHUser, host.SSHPort
		hostPassword, hostKey, hostPassphrase = host.SSHPassword, host.SSHKeyFile, host.SSHKeyPassphrase
	}

	// Resolve the user to log in as: explicit, then database, then $USER.
	if c.User == "" {
		c.User = hostUser
	}
	if c.User == "" {
		c.User = os.Getenv("USER")
	}

	if c.Port == 0 {
		c.Port = 22
		if hostPort > 0 {
			c.Port = hostPort
		}
	}

	if c.Password == "" {
		c.Password = hostPassword
	}

	switch {
	case c.Key != "":
		// Keep the explicitly configured key path.
	case hostKey != "":
		c.Key = replaceHomeDir(hostKey)
	default:
		c.Key = path.Join(os.Getenv("HOME"), ".ssh/id_rsa")
	}

	if c.KeyPassphrase == "" {
		c.KeyPassphrase = hostPassphrase
	}

	keyContent, err := readSSHKey(c.Key)
	if err != nil {
		return err
	}
	c.KeyContent = keyContent

	// Bastion (jump host) handling. Database values stay zero when there is no record.
	var jumpIP, jumpUser, jumpPassword, jumpKey, jumpPassphrase string
	var jumpPort int
	if bastion != nil {
		jumpIP, jumpUser, jumpPort = bastion.ConnectIP, bastion.SSHUser, bastion.SSHPort
		jumpPassword, jumpKey, jumpPassphrase = bastion.SSHPassword, bastion.SSHKeyFile, bastion.SSHKeyPassphrase
	}

	if c.BastionHost == "" {
		c.BastionHost = jumpIP
	}
	// Neither the command line nor the database names a bastion: connect directly.
	if c.BastionHost == "" {
		return nil
	}

	if c.BastionPort == 0 {
		c.BastionPort = 22
		if jumpPort > 0 {
			c.BastionPort = jumpPort
		}
	}

	if c.BastionUser == "" {
		c.BastionUser = jumpUser
	}
	if c.BastionUser == "" {
		c.BastionUser = os.Getenv("USER")
	}

	if c.BastionPassword == "" {
		c.BastionPassword = jumpPassword
	}

	switch {
	case c.BastionKey != "":
		// Keep the explicitly configured bastion key path.
	case jumpKey != "":
		c.BastionKey = replaceHomeDir(jumpKey)
	default:
		c.BastionKey = path.Join(os.Getenv("HOME"), ".ssh/id_rsa")
	}

	if c.BastionKeyPassphrase == "" {
		c.BastionKeyPassphrase = jumpPassphrase
	}

	bastionKeyContent, err := readSSHKey(c.BastionKey)
	if err != nil {
		return err
	}
	c.BastionKeyContent = bastionKeyContent

	return nil
}

// getHostFromDB looks up the target host record matching c.Condition and,
// when that record names a bastion IP, the bastion record whose connect_ip
// equals it.
//
// Return values:
//   - (nil, nil, nil) when the database is disabled
//     (config.GlobalConfig.UsedDB is false), or when the host query returns
//     no host and no error.
//   - (host, nil, nil) for a host without a bastion IP.
//   - (nil, nil, err) when either query fails.
func (c *Config) getHostFromDB() (*model.Host, *model.Bastion, error) {
	if !config.GlobalConfig.UsedDB {
		return nil, nil, nil
	}

	host, err := model.QueryHostByCondition(c.Condition)
	switch {
	case err != nil:
		return nil, nil, err
	case host == nil:
		return nil, nil, nil
	case host.BastionIP == "":
		return host, nil, nil
	}

	bastionCond := map[string]string{"connect_ip": host.BastionIP}
	bastion, err := model.QueryBastionByCondition(bastionCond)
	if err != nil {
		return nil, nil, err
	}
	return host, bastion, nil
}

// readSSHKey reads the whole file at the given path and returns its contents
// as a string. On failure it returns an empty string together with the error
// from ioutil.ReadFile, unchanged.
func readSSHKey(file string) (content string, err error) {
	var raw []byte
	if raw, err = ioutil.ReadFile(file); err == nil {
		content = string(raw)
	}
	return content, err
}

// replaceHomeDir expands a leading "~" in file to the value of $HOME.
//
// Only the current user's home is expanded: "~" on its own and paths that
// start with "~/". The "~user/..." form is returned unchanged. The previous
// version turned it into "$HOMEuser/...", a silently wrong path. Any other
// path is also returned unchanged.
func replaceHomeDir(file string) string {
	if file == "~" || strings.HasPrefix(file, "~/") {
		return os.Getenv("HOME") + file[1:]
	}
	return file
}
